# -*- coding:utf-8 -*-

import numpy as np
from utils.log_utils import log_debug

"""
数据方法
"""


class DictObj(object):
    """
    Wrap a dict so its keys can be read and written as attributes.

    Reading ``obj.key`` returns the stored value, with these conversions:

    * a dict value is returned wrapped in a new ``DictObj``;
    * a list value is returned as a new list in which dict elements are
      wrapped in ``DictObj`` and all other elements are returned as-is;
    * the ``xs_shape`` and ``ys_shape`` keys are always returned raw, because
      the brackets of a shape must not be parsed into nested levels.

    A key that is missing reads as ``None``, except for dunder names, which
    raise AttributeError so that protocols such as copy/pickle work.
    Assigning ``obj.key = value`` writes into the wrapped dict.
    ``obj[key]`` returns the raw stored value.
    """
    def __init__(self, map):
        """
        :param map: the dict to wrap. It is referenced, not copied, so
            attribute assignments write through to it.
        """
        self.map = map

    def __setattr__(self, name, value):
        """Store ``value`` under ``name`` in the wrapped dict; ``map`` itself is a real attribute."""
        if name == 'map':
            object.__setattr__(self, name, value)
            return
        log_debug('\033[30madd key={}, value={}\033[0m'.format(name, value))
        self.map[name] = value

    def __getattr__(self, name):
        """
        Look up ``name`` in the wrapped dict. Only called when normal
        attribute lookup fails.

        :return: the converted value (see class docstring), or ``None`` if
            the key is missing.
        :raises AttributeError: if the instance has no ``map`` yet (e.g. it
            was created by copy/pickle without ``__init__``), or if ``name``
            is a dunder name that is not a key of the wrapped dict.
        """
        # Read via __dict__: touching self.map on an instance without 'map'
        # would re-enter __getattr__('map') and recurse forever.
        if 'map' not in self.__dict__:
            raise AttributeError(name)
        mapping = self.__dict__['map']
        if name in mapping:
            value = mapping[name]
            # Use equality, not identity: ``is`` on strings depends on
            # interning and silently misses equal, non-interned names.
            if name in ('xs_shape', 'ys_shape'):
                return value
            if isinstance(value, dict):
                return DictObj(value)
            if isinstance(value, list):
                # Only dict elements can be navigated as attributes; wrapping
                # scalars would make them unusable.
                return [DictObj(item) if isinstance(item, dict) else item
                        for item in value]
            return value
        # Protocol lookups such as __getstate__, __setstate__ and
        # __deepcopy__ must see AttributeError; returning None breaks
        # copy/pickle.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return None

    def __getitem__(self, name):
        """Return the raw value stored under ``name``; raises KeyError if it is missing."""
        return self.map[name]


def train_vali_split(x, vali_size=0.3):
    """
    Randomly split a dataset into a training set and a validation set.

    :param x: dataset including labels. pandas objects are indexed by row
        position via ``.iloc``; any other array-like (e.g. a numpy array) is
        indexed with ``x[indices]`` along its first axis.
    :param vali_size: fraction of the rows that go to the validation set. The
        count is ``int(n_rows * vali_size)``, clamped to ``[0, n_rows]``.
    :return: ``(train, vali)``. Training rows keep their original relative
        order; validation rows are in random order.
    """
    x_num = x.shape[0]
    vali_num = max(0, min(x_num, int(x_num * vali_size)))
    # A single random permutation replaces the draw-and-delete loop, which
    # was O(n^2) and fails on Python 3 (range objects do not support del).
    perm = np.random.permutation(x_num)
    vali_index = perm[:vali_num]
    # Sorting keeps the training rows in their original order.
    train_index = np.sort(perm[vali_num:])
    # The indices are row positions in x. .ix was removed in pandas 1.0, and
    # .iloc is the positional indexer.
    if hasattr(x, 'iloc'):
        return x.iloc[train_index], x.iloc[vali_index]
    return x[train_index], x[vali_index]


def features_labels_split(arr, seq_size):
    """
    Cut a sequence into fixed-length feature rows and their label rows.

    The first ``len(arr) // seq_size * seq_size`` elements are reshaped into
    rows of ``seq_size`` elements. Trailing elements that cannot fill a row
    are dropped. Each label row is its feature row rotated left by one:
    ``label[i] == feature[i + 1]``, and the last label wraps around to
    ``feature[0]``.

    :param arr: sequence to split (a numpy array, or anything
        ``np.asarray`` accepts).
    :param seq_size: number of elements per row.
    :return: ``[features, labels, num_rows]``. ``features`` and ``labels``
        are new arrays of shape ``(num_rows, seq_size)``, even when
        ``num_rows`` is 0.
    """
    arr = np.asarray(arr)
    # Integer division keeps the slice a multiple of seq_size; otherwise
    # reshape would raise.
    num = len(arr) // seq_size
    # np.array copies, so the returned features never alias the caller's data.
    features = np.array(arr[:num * seq_size].reshape((-1, seq_size)))
    # Rotate every row left by one in a single vectorized call.
    labels = np.roll(features, -1, axis=1)
    return [features, labels, features.shape[0]]


def _copy_nested_dicts(value):
    """
    Copy ``value`` if it is a dict, recursing into nested dicts.

    Only the dict levels are copied; non-dict values (lists, objects, ...)
    are shared with the original.
    """
    if not isinstance(value, dict):
        return value
    # dict.copy() keeps OrderedDict/defaultdict types for those subclasses.
    result = value.copy()
    for key in list(result):
        result[key] = _copy_nested_dicts(result[key])
    return result


def dict_merge(new_dict, default):
    """
    Recursively fill the keys missing from ``new_dict`` with values from ``default``.

    Keys already present in ``new_dict`` win. Where both sides hold dicts, the
    merge recurses into them. If either argument is not a dict, ``new_dict``
    is returned unchanged.

    :param new_dict: the dict to complete; it is modified in place.
    :param default: dict of fallback values. It is never modified. Nested
        dicts taken from it are copied, so later in-place changes to the
        result's dicts do not reach ``default``. Non-dict values such as
        lists are still shared.
    :return: ``new_dict``.
    """
    if isinstance(new_dict, dict) and isinstance(default, dict):
        for k, v in default.items():
            if k not in new_dict:
                # Copy nested dicts: a later merge into the result writes into
                # these dicts in place and would otherwise corrupt ``default``.
                new_dict[k] = _copy_nested_dicts(v)
            else:
                new_dict[k] = dict_merge(new_dict[k], v)
    return new_dict
